import multiprocessing
import os
import threading
import time

from collections.abc import Callable
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING

from sqlalchemy.engine import Engine

from memos.configs.mem_scheduler import AuthConfig, BaseSchedulerConfig
from memos.context.context import ContextThread
from memos.llms.base import BaseLLM
from memos.log import get_logger
from memos.mem_cube.base import BaseMemCube
from memos.mem_cube.general import GeneralMemCube
from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue
from memos.mem_scheduler.general_modules.scheduler_logger import SchedulerLoggerModule
from memos.mem_scheduler.memory_manage_modules.retriever import SchedulerRetriever
from memos.mem_scheduler.monitors.dispatcher_monitor import SchedulerDispatcherMonitor
from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor
from memos.mem_scheduler.schemas.general_schemas import (
    DEFAULT_ACT_MEM_DUMP_PATH,
    DEFAULT_CONSUME_BATCH,
    DEFAULT_CONSUME_INTERVAL_SECONDS,
    DEFAULT_CONTEXT_WINDOW_SIZE,
    DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE,
    DEFAULT_MAX_WEB_LOG_QUEUE_SIZE,
    DEFAULT_STARTUP_MODE,
    DEFAULT_THREAD_POOL_MAX_WORKERS,
    DEFAULT_TOP_K,
    DEFAULT_USE_REDIS_QUEUE,
    STARTUP_BY_PROCESS,
    MemCubeID,
    TreeTextMemory_SEARCH_METHOD,
    UserID,
)
from memos.mem_scheduler.schemas.message_schemas import (
    ScheduleLogForWebItem,
    ScheduleMessageItem,
)
from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem
from memos.mem_scheduler.task_schedule_modules.dispatcher import SchedulerDispatcher
from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue
from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue
from memos.mem_scheduler.utils.db_utils import get_utc_now
from memos.mem_scheduler.utils.filter_utils import (
    transform_name_to_key,
)
from memos.mem_scheduler.webservice_modules.rabbitmq_service import RabbitMQSchedulerModule
from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule
from memos.memories.activation.kv import KVCacheMemory
from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory
from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE


if TYPE_CHECKING:
    from memos.reranker.http_bge import HTTPBGEReranker


logger = get_logger(__name__)


class BaseScheduler(RabbitMQSchedulerModule, RedisSchedulerModule, SchedulerLoggerModule):
    """Base class for all mem_scheduler."""

    def __init__(self, config: BaseSchedulerConfig):
        """Build the scheduler state from ``config``.

        Every tunable is read through ``config.get(key, fallback)``, so missing
        keys fall back to the module-level defaults. The web-log queue, the task
        queue and the dispatcher are created here; the members that
        ``initialize_modules`` fills in (monitor, retriever, DB engine, ...)
        start out as ``None``.

        Args:
            config: Scheduler configuration object.
        """
        super().__init__()
        self.config = config
        setting = self.config.get  # shorthand: setting(key, fallback)

        # Retrieval and dispatch tunables.
        self.top_k = setting("top_k", DEFAULT_TOP_K)
        self.context_window_size = setting("context_window_size", DEFAULT_CONTEXT_WINDOW_SIZE)
        self.enable_activation_memory = setting("enable_activation_memory", False)
        self.act_mem_dump_path = setting("act_mem_dump_path", DEFAULT_ACT_MEM_DUMP_PATH)
        self.search_method = setting("search_method", TreeTextMemory_SEARCH_METHOD)
        self.enable_parallel_dispatch = setting("enable_parallel_dispatch", True)
        self.thread_pool_max_workers = setting(
            "thread_pool_max_workers", DEFAULT_THREAD_POOL_MAX_WORKERS
        )

        # How the consumer is launched (compared against STARTUP_BY_PROCESS later).
        self.scheduler_startup_mode = setting("scheduler_startup_mode", DEFAULT_STARTUP_MODE)

        # Optional: handler labels handed to the task queue as disabled.
        self.disabled_handlers: list | None = setting("disabled_handlers", None)

        # Queue of log items kept for the web side.
        self.max_web_log_queue_size = setting(
            "max_web_log_queue_size", DEFAULT_MAX_WEB_LOG_QUEUE_SIZE
        )
        self._web_log_message_queue: Queue[ScheduleLogForWebItem] = Queue(
            maxsize=self.max_web_log_queue_size
        )

        # Consumer bookkeeping: at most one of thread/process is set while running.
        self._consumer_thread = None
        self._consumer_process = None
        self._running = False
        self._consume_interval = setting(
            "consume_interval_seconds", DEFAULT_CONSUME_INTERVAL_SECONDS
        )
        self.consume_batch = setting("consume_batch", DEFAULT_CONSUME_BATCH)

        # Task queue shared between the scheduler and its dispatcher.
        self.use_redis_queue = setting("use_redis_queue", DEFAULT_USE_REDIS_QUEUE)
        self.max_internal_message_queue_size = setting(
            "max_internal_message_queue_size", DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE
        )
        self.memos_message_queue = ScheduleTaskQueue(
            use_redis_queue=self.use_redis_queue,
            maxsize=self.max_internal_message_queue_size,
            disabled_handlers=self.disabled_handlers,
        )

        # Filled in later by init_mem_cube / initialize_modules.
        self.searcher: Searcher | None = None
        self.retriever: SchedulerRetriever | None = None
        self.db_engine: Engine | None = None
        self.monitor: SchedulerGeneralMonitor | None = None
        self.dispatcher_monitor: SchedulerDispatcherMonitor | None = None
        self.mem_reader = None  # Will be set by MOSCore

        self.dispatcher = SchedulerDispatcher(
            config=self.config,
            memos_message_queue=self.memos_message_queue,
            use_redis_queue=self.use_redis_queue,
            max_workers=self.thread_pool_max_workers,
            enable_parallel_dispatch=self.enable_parallel_dispatch,
        )

        # Current-context bookkeeping and auth/RabbitMQ settings.
        self._context_lock = threading.Lock()
        self.current_user_id: UserID | str | None = None
        self.current_mem_cube_id: MemCubeID | str | None = None
        self.current_mem_cube: BaseMemCube | None = None
        self.auth_config_path: str | Path | None = setting("auth_config_path", None)
        self.auth_config = None
        self.rabbitmq_config = None

    def init_mem_cube(
        self,
        mem_cube: BaseMemCube,
        searcher: Searcher | None = None,
    ):
        self.mem_cube = mem_cube
        self.text_mem: TreeTextMemory = self.mem_cube.text_mem
        self.reranker: HTTPBGEReranker = self.text_mem.reranker
        if searcher is None:
            self.searcher: Searcher = self.text_mem.get_searcher(
                manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false",
                moscube=False,
            )
        else:
            self.searcher = searcher

    def initialize_modules(
        self,
        chat_llm: BaseLLM,
        process_llm: BaseLLM | None = None,
        db_engine: Engine | None = None,
        mem_reader=None,
    ):
        """Create the scheduler's submodules and optional RabbitMQ connection.

        Builds the general monitor, dispatcher monitor and retriever, starts the
        dispatcher monitor when parallel dispatch is enabled, and then tries to
        load an auth config (explicit path, then default local config, then
        environment). Loading the auth config is best-effort: a failure is
        logged and initialization continues without it.

        Args:
            chat_llm: LLM stored as ``self.chat_llm``.
            process_llm: LLM handed to the monitor and retriever; defaults to
                ``chat_llm`` when ``None``.
            db_engine: Optional SQLAlchemy engine passed to the general monitor.
                ``self.db_engine`` is afterwards taken from the monitor.
            mem_reader: Optional reader; stored only when truthy.

        Raises:
            Exception: Any error from submodule construction is logged, cleanup
                is attempted via ``_cleanup_on_init_failure``, and the error is
                re-raised.
        """
        if process_llm is None:
            process_llm = chat_llm

        try:
            # initialize submodules
            self.chat_llm = chat_llm
            self.process_llm = process_llm
            self.db_engine = db_engine
            self.monitor = SchedulerGeneralMonitor(
                process_llm=self.process_llm, config=self.config, db_engine=self.db_engine
            )
            # The monitor owns the engine that is actually used from here on.
            self.db_engine = self.monitor.db_engine
            self.dispatcher_monitor = SchedulerDispatcherMonitor(config=self.config)
            self.retriever = SchedulerRetriever(process_llm=self.process_llm, config=self.config)

            if mem_reader:
                self.mem_reader = mem_reader

            if self.enable_parallel_dispatch:
                self.dispatcher_monitor.initialize(dispatcher=self.dispatcher)
                self.dispatcher_monitor.start()

            # initialize with auth_config
            try:
                if self.auth_config_path is not None and Path(self.auth_config_path).exists():
                    self.auth_config = AuthConfig.from_local_config(
                        config_path=self.auth_config_path
                    )
                elif AuthConfig.default_config_exists():
                    self.auth_config = AuthConfig.from_local_config()
                else:
                    self.auth_config = AuthConfig.from_local_env()
            except Exception as auth_err:
                # Auth config is optional, so keep going - but leave a trace
                # instead of hiding the failure completely.
                logger.warning(f"Could not load auth config, continuing without it: {auth_err}")

            if self.auth_config is not None:
                self.rabbitmq_config = self.auth_config.rabbitmq
                if self.rabbitmq_config is not None:
                    self.initialize_rabbitmq(config=self.rabbitmq_config)

            logger.debug("GeneralScheduler has been initialized")
        except Exception as e:
            logger.error(f"Failed to initialize scheduler modules: {e}", exc_info=True)
            # Clean up any partially initialized resources
            self._cleanup_on_init_failure()
            raise

        # start queue monitor if enabled and a bot is set later

    def debug_mode_on(self):
        self.memos_message_queue.debug_mode_on()

    def _cleanup_on_init_failure(self):
        """Clean up resources if initialization fails."""
        try:
            if hasattr(self, "dispatcher_monitor") and self.dispatcher_monitor is not None:
                self.dispatcher_monitor.stop()
        except Exception as e:
            logger.warning(f"Error during cleanup: {e}")

    @property
    def mem_cube(self) -> BaseMemCube:
        """The memory cube associated with this MemChat."""
        return self.current_mem_cube

    @mem_cube.setter
    def mem_cube(self, value: BaseMemCube) -> None:
        """The memory cube associated with this MemChat."""
        self.current_mem_cube = value
        self.retriever.mem_cube = value

    def transform_working_memories_to_monitors(
        self, query_keywords, memories: list[TextualMemoryItem]
    ) -> list[MemoryMonitorItem]:
        """
        Convert a list of TextualMemoryItem objects into MemoryMonitorItem objects
        with importance scores based on keyword matching.

        Args:
            query_keywords: Mapping iterated with ``.items()`` as
                ``keyword -> weight``. When it is empty/None, every item gets a
                ``keywords_score`` of 0.
            memories: List of TextualMemoryItem objects to be transformed. Their
                order defines the rank score (earlier items score higher).

        Returns:
            List of MemoryMonitorItem objects with computed importance scores.
        """

        result = []
        mem_length = len(memories)
        for idx, mem in enumerate(memories):
            text_mem = mem.memory
            mem_key = transform_name_to_key(name=text_mem)

            # Importance score: sum over keywords of (occurrences in text * weight)
            keywords_score = 0
            if query_keywords and text_mem:
                for keyword, count in query_keywords.items():
                    keyword_count = text_mem.count(keyword)
                    if keyword_count > 0:
                        increment = keyword_count * count
                        keywords_score += increment
                        # Report the amount added by this keyword, not the running total
                        logger.debug(
                            f"Matched keyword '{keyword}' {keyword_count} times, "
                            f"added {increment} to keywords_score (now {keywords_score})"
                        )

            # rank score: first item gets mem_length, last item gets 1
            sorting_score = mem_length - idx

            mem_monitor = MemoryMonitorItem(
                memory_text=text_mem,
                tree_memory_item=mem,
                tree_memory_item_mapping_key=mem_key,
                sorting_score=sorting_score,
                keywords_score=keywords_score,
                recording_count=1,
            )
            result.append(mem_monitor)

        logger.info(f"Transformed {len(result)} memories to monitors")
        return result

    def replace_working_memory(
        self,
        user_id: UserID | str,
        mem_cube_id: MemCubeID | str,
        mem_cube: GeneralMemCube,
        original_memory: list[TextualMemoryItem],
        new_memory: list[TextualMemoryItem],
    ) -> None | list[TextualMemoryItem]:
        """Rerank, filter and install a new working memory for ``mem_cube``.

        Steps, for a cube whose text memory is a ``TreeTextMemory``:
        1. sync the query monitor with the database and read the query history;
        2. let the retriever rerank ``original_memory`` + ``new_memory``;
        3. let the retriever filter out memories unrelated to the history
           (the unfiltered list is kept if filtering fails);
        4. update the working-memory monitors and replace the cube's working
           memory with the monitors' sorted items;
        5. emit a web log describing the replacement.

        Args:
            user_id: Key into the monitor's per-user tables.
            mem_cube_id: Key into the monitor's per-cube tables.
            mem_cube: Cube whose working memory is replaced.
            original_memory: Working memory before the update.
            new_memory: Candidate memories to merge in.

        Returns:
            The reranked (and, when filtering succeeded, filtered) memories.
            For any other text-memory type, ``new_memory`` is returned untouched
            after logging an error.
        """
        text_mem_base = mem_cube.text_mem
        if not isinstance(text_mem_base, TreeTextMemory):
            logger.error("memory_base is not supported")
            return new_memory

        # Refresh the query history from the database before reranking.
        query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id]
        query_db_manager.sync_with_orm()
        query_history = query_db_manager.obj.get_queries_with_timesort()

        rerank_outcome = self.retriever.process_and_rerank_memories(
            queries=query_history,
            original_memory=original_memory,
            new_memory=new_memory,
            top_k=self.top_k,
        )
        memories_with_new_order, rerank_success_flag = rerank_outcome

        # Drop memories that have nothing to do with the recent queries.
        logger.info(f"Filtering memories based on query history: {len(query_history)} queries")
        filtered_memories, filter_success_flag = self.retriever.filter_unrelated_memories(
            query_history=query_history,
            memories=memories_with_new_order,
        )

        if not filter_success_flag:
            logger.warning(
                "Memory filtering failed - keeping all memories as fallback. "
                f"Original count: {len(memories_with_new_order)}"
            )
        else:
            logger.info(
                f"Memory filtering completed successfully. "
                f"Filtered from {len(memories_with_new_order)} to {len(filtered_memories)} memories"
            )
            memories_with_new_order = filtered_memories

        # Turn the surviving memories into monitor items scored by keywords.
        query_keywords = query_db_manager.obj.get_keywords_collections()
        logger.info(
            f"Processing {len(memories_with_new_order)} memories with {len(query_keywords)} query keywords"
        )
        new_working_memory_monitors = self.transform_working_memories_to_monitors(
            query_keywords=query_keywords,
            memories=memories_with_new_order,
        )

        # A failed rerank means the order carries no information: flatten it.
        if not rerank_success_flag:
            for monitor_item in new_working_memory_monitors:
                monitor_item.sorting_score = 0

        logger.info(f"update {len(new_working_memory_monitors)} working_memory_monitors")
        self.monitor.update_working_memory_monitors(
            new_working_memory_monitors=new_working_memory_monitors,
            user_id=user_id,
            mem_cube_id=mem_cube_id,
            mem_cube=mem_cube,
        )

        # The working memory is rebuilt from the monitors' own sorted view.
        working_monitor_manager = self.monitor.working_memory_monitors[user_id][mem_cube_id]
        mem_monitors: list[MemoryMonitorItem] = working_monitor_manager.obj.get_sorted_mem_monitors(
            reverse=True
        )
        new_working_memories = [mem_monitor.tree_memory_item for mem_monitor in mem_monitors]

        text_mem_base.replace_working_memory(memories=new_working_memories)

        logger.info(
            f"The working memory has been replaced with {len(memories_with_new_order)} new memories."
        )
        self.log_working_memory_replacement(
            original_memory=original_memory,
            new_memory=new_working_memories,
            user_id=user_id,
            mem_cube_id=mem_cube_id,
            mem_cube=mem_cube,
            log_func_callback=self._submit_web_logs,
        )

        return memories_with_new_order

    def update_activation_memory(
        self,
        new_memories: list[str | TextualMemoryItem],
        label: str,
        user_id: UserID | str,
        mem_cube_id: MemCubeID | str,
        mem_cube: GeneralMemCube,
    ) -> None:
        """Rebuild the cube's activation (KV-cache) memory from ``new_memories``.

        The memories are numbered, assembled with ``MEMORY_ASSEMBLY_TEMPLATE`` and
        compared with the composed text of the most recent cache item. If they
        match, nothing happens; otherwise the existing cache is cleared, a new
        cache item is extracted, added and dumped to ``self.act_mem_dump_path``,
        and a web log is emitted.

        Args:
            new_memories: Memory texts, or ``TextualMemoryItem`` objects whose
                ``memory`` text is used. The type of the first element decides
                how the whole list is interpreted.
            label: Label forwarded to the activation-memory log entry.
            user_id: User the update belongs to (used for logging).
            mem_cube_id: Cube id the update belongs to (used for logging).
            mem_cube: Cube whose ``act_mem`` is updated; only
                ``VLLMKVCacheMemory`` and ``KVCacheMemory`` are handled.

        Any exception raised during the cache update is logged and swallowed.
        """
        if len(new_memories) == 0:
            logger.error("update_activation_memory: new_memory is empty.")
            return

        # The first element decides how the whole list is read.
        first_entry = new_memories[0]
        if isinstance(first_entry, TextualMemoryItem):
            new_text_memories = [mem.memory for mem in new_memories]
        elif isinstance(first_entry, str):
            new_text_memories = new_memories
        else:
            logger.error("Not Implemented.")
            return

        try:
            act_mem = mem_cube.act_mem
            if not isinstance(act_mem, (VLLMKVCacheMemory, KVCacheMemory)):
                logger.error("Not Implemented.")
                return

            # Numbering follows the position in the input list; blank entries
            # are skipped but still consume their number.
            numbered_lines = (
                f"{i + 1}. {sentence.strip()}\n"
                for i, sentence in enumerate(new_text_memories)
                if sentence.strip()
            )
            new_text_memory = MEMORY_ASSEMBLY_TEMPLATE.format(memory_text="".join(numbered_lines))

            # huggingface or vllm kv cache
            original_cache_items: list[VLLMKVCacheItem] = act_mem.get_all()
            original_text_memories = []
            if len(original_cache_items) > 0:
                latest_records = original_cache_items[-1].records
                original_text_memories = latest_records.text_memories
                if latest_records.composed_text_memory == new_text_memory:
                    if len(new_text_memory) > 50:
                        preview = new_text_memory[:50] + "..."
                    else:
                        preview = new_text_memory
                    logger.warning(
                        "Skipping memory update - new composition matches existing cache: %s",
                        preview,
                    )
                    return
                # Content changed: the old cache is discarded before rebuilding.
                act_mem.delete_all()

            cache_item = act_mem.extract(new_text_memory)
            cache_item.records.text_memories = new_text_memories
            cache_item.records.timestamp = get_utc_now()

            act_mem.add([cache_item])
            act_mem.dump(self.act_mem_dump_path)

            self.log_activation_memory_update(
                original_text_memories=original_text_memories,
                new_text_memories=new_text_memories,
                label=label,
                user_id=user_id,
                mem_cube_id=mem_cube_id,
                mem_cube=mem_cube,
                log_func_callback=self._submit_web_logs,
            )

        except Exception as e:
            # Deliberately not re-raised: the caller continues after a failed update.
            logger.error(f"MOS-based activation memory update failed: {e}", exc_info=True)

    def update_activation_memory_periodically(
        self,
        interval_seconds: int,
        label: str,
        user_id: UserID | str,
        mem_cube_id: MemCubeID | str,
        mem_cube: GeneralMemCube,
    ):
        """Refresh activation memory when the update interval has elapsed.

        The update runs when no update has happened yet (the monitor's last
        update time is ``datetime.min``) or when the monitor's ``timed_trigger``
        fires for ``interval_seconds``. It is skipped when the working-memory
        monitors hold no memories for this user/cube.

        Args:
            interval_seconds: Minimum time between two updates.
            label: Label forwarded to ``update_activation_memory``.
            user_id: User whose monitors are consulted.
            mem_cube_id: Cube whose monitors are consulted.
            mem_cube: Cube whose activation memory is updated.

        Any exception is logged and swallowed.
        """
        try:
            monitor = self.monitor

            is_due = (
                monitor.last_activation_mem_update_time == datetime.min
                or monitor.timed_trigger(
                    last_time=monitor.last_activation_mem_update_time,
                    interval_seconds=interval_seconds,
                )
            )
            if not is_due:
                logger.info(
                    f"Skipping update - {interval_seconds} second interval not yet reached. "
                    f"Last update time is {monitor.last_activation_mem_update_time} and now is "
                    f"{get_utc_now()}"
                )
                return

            logger.info(f"Updating activation memory for user {user_id} and mem_cube {mem_cube_id}")

            # Nothing to do unless this user/cube has at least one working memory.
            working_monitors = monitor.working_memory_monitors
            has_working_memories = (
                user_id in working_monitors
                and mem_cube_id in working_monitors[user_id]
                and len(working_monitors[user_id][mem_cube_id].obj.memories) != 0
            )
            if not has_working_memories:
                logger.warning(
                    "No memories found in working_memory_monitors, activation memory update is skipped"
                )
                return

            monitor.update_activation_memory_monitors(
                user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube
            )

            # Sync with database to get latest activation memories
            activation_db_manager = monitor.activation_memory_monitors[user_id][mem_cube_id]
            activation_db_manager.sync_with_orm()
            new_activation_memories = [
                item.memory_text for item in activation_db_manager.obj.memories
            ]

            logger.info(
                f"Collected {len(new_activation_memories)} new memory entries for processing"
            )
            # Log a short preview of at most the first five entries.
            for i, memory in enumerate(new_activation_memories[:5], 1):
                logger.info(
                    f"Part of New Activation Memorires | {i}/{len(new_activation_memories)}: {memory[:20]}"
                )

            self.update_activation_memory(
                new_memories=new_activation_memories,
                label=label,
                user_id=user_id,
                mem_cube_id=mem_cube_id,
                mem_cube=mem_cube,
            )

            monitor.last_activation_mem_update_time = get_utc_now()

            logger.debug(
                f"Activation memory update completed at {monitor.last_activation_mem_update_time}"
            )
        except Exception as e:
            logger.error(f"Error in update_activation_memory_periodically: {e}", exc_info=True)

    def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]):
        self.memos_message_queue.submit_messages(messages=messages)

    def _submit_web_logs(
        self, messages: ScheduleLogForWebItem | list[ScheduleLogForWebItem]
    ) -> None:
        """Submit log messages to the web log queue and optionally to RabbitMQ.

        Args:
            messages: Single log message or list of log messages
        """
        if self.rabbitmq_config is None:
            return

        if isinstance(messages, ScheduleLogForWebItem):
            messages = [messages]  # transform single message to list

        for message in messages:
            if not isinstance(message, ScheduleLogForWebItem):
                error_msg = f"Invalid message type: {type(message)}, expected ScheduleLogForWebItem"
                logger.error(error_msg)
                raise TypeError(error_msg)

            self._web_log_message_queue.put(message)
            message_info = message.debug_info()
            logger.debug(f"Submitted Scheduling log for web: {message_info}")

            if self.is_rabbitmq_connected():
                logger.info(f"Submitted Scheduling log to rabbitmq: {message_info}")
                self.rabbitmq_publish_message(message=message.to_dict())
        logger.debug(f"{len(messages)} submitted. {self._web_log_message_queue.qsize()} in queue.")

    def get_web_log_messages(self) -> list[dict]:
        """
        Retrieves all web log messages from the queue and returns them as a list of JSON-serializable dictionaries.

        Returns:
            List[dict]: A list of dictionaries representing ScheduleLogForWebItem objects,
                       ready for JSON serialization. The list is ordered from oldest to newest.
        """
        messages = []
        while True:
            try:
                item = self._web_log_message_queue.get_nowait()  # Thread-safe get
                messages.append(item.to_dict())
            except Exception:
                break
        return messages

    def _message_consumer(self) -> None:
        """
        Continuously checks the queue for messages and dispatches them.

        Runs in a dedicated thread to process messages at regular intervals.
        For Redis queue, this method starts the Redis listener.
        """

        # Original local queue logic
        while self._running:  # Use a running flag for graceful shutdown
            try:
                # Get messages in batches based on consume_batch setting

                messages = self.memos_message_queue.get_messages(batch_size=self.consume_batch)

                if messages:
                    try:
                        import contextlib

                        with contextlib.suppress(Exception):
                            if messages:
                                self.dispatcher.on_messages_enqueued(messages)

                        self.dispatcher.dispatch(messages)
                    except Exception as e:
                        logger.error(f"Error dispatching messages: {e!s}")

                # Sleep briefly to prevent busy waiting
                time.sleep(self._consume_interval)  # Adjust interval as needed

            except Exception as e:
                # Don't log error for "No messages available in Redis queue" as it's expected
                if "No messages available in Redis queue" not in str(e):
                    logger.error(f"Unexpected error in message consumer: {e!s}")
                time.sleep(self._consume_interval)  # Prevent tight error loops

    def start(self) -> None:
        """
        Start the message consumer thread/process and initialize dispatcher resources.

        Initializes and starts:
        1. Message consumer thread or process (based on startup_mode)
        2. Dispatcher thread pool (if parallel dispatch enabled)
        """
        # Initialize dispatcher resources
        if self.enable_parallel_dispatch:
            logger.info(
                f"Initializing dispatcher thread pool with {self.thread_pool_max_workers} workers"
            )

        self.start_consumer()

    def start_consumer(self) -> None:
        """
        Launch the message consumer as a daemon thread or daemon process.

        Does nothing (apart from a warning) when the consumer is already
        running. A process is used when ``scheduler_startup_mode`` equals
        ``STARTUP_BY_PROCESS``; every other mode uses a thread. Can be called
        again after ``stop_consumer()`` to resume consumption.
        """
        if self._running:
            logger.warning("Memory Scheduler consumer is already running")
            return

        # Raise the flag first: the consumer loop checks it as soon as it starts.
        self._running = True

        launch_as_process = self.scheduler_startup_mode == STARTUP_BY_PROCESS
        if not launch_as_process:
            # Default to thread mode
            consumer_thread = ContextThread(
                target=self._message_consumer,
                daemon=True,
                name="MessageConsumerThread",
            )
            self._consumer_thread = consumer_thread
            consumer_thread.start()
            logger.info("Message consumer thread started")
            return

        consumer_process = multiprocessing.Process(
            target=self._message_consumer,
            daemon=True,
            name="MessageConsumerProcess",
        )
        self._consumer_process = consumer_process
        consumer_process.start()
        logger.info("Message consumer process started")

    def stop_consumer(self) -> None:
        """Stop only the message consumer thread/process gracefully.

        This method stops the consumer without affecting other components like
        dispatcher or monitors. Useful when you want to pause message processing
        while keeping other scheduler components running.

        The consumer is given 5 seconds to finish. A process that is still alive
        afterwards is terminated; a thread that is still alive is only reported.
        In both modes the stored consumer reference is cleared afterwards, even
        if the consumer had already exited on its own.
        """
        if not self._running:
            logger.warning("Memory Scheduler consumer is not running")
            return

        # Signal consumer thread/process to stop
        self._running = False

        if self.scheduler_startup_mode == STARTUP_BY_PROCESS and self._consumer_process:
            # NOTE(review): `_running` is a plain attribute, so clearing it here is
            # presumably not visible inside a separate consumer process; the join
            # below would then always time out and end in terminate() - confirm.
            if self._consumer_process.is_alive():
                self._consumer_process.join(timeout=5.0)
                if self._consumer_process.is_alive():
                    logger.warning("Consumer process did not stop gracefully, terminating...")
                    self._consumer_process.terminate()
                    self._consumer_process.join(timeout=2.0)
                    if self._consumer_process.is_alive():
                        logger.error("Consumer process could not be terminated")
                    else:
                        logger.info("Consumer process terminated")
                else:
                    logger.info("Consumer process stopped")
            self._consumer_process = None
        elif self._consumer_thread:
            if self._consumer_thread.is_alive():
                self._consumer_thread.join(timeout=5.0)
                if self._consumer_thread.is_alive():
                    logger.warning("Consumer thread did not stop gracefully")
                else:
                    logger.info("Consumer thread stopped")
            # Clear the handle even when the thread had already exited, so a
            # stale reference is not kept around (same as the process branch).
            self._consumer_thread = None

        logger.info("Memory Scheduler consumer stopped")

    def stop(self) -> None:
        """Stop all scheduler components gracefully.

        1. Stops message consumer thread/process (if it is running)
        2. Shuts down dispatcher thread pool
        3. Stops the dispatcher monitor

        Steps 2 and 3 run even when the consumer is not running. The consumer
        can be paused on its own with ``stop_consumer()``, and the dispatcher
        monitor may already be started by ``initialize_modules``; returning
        early in those cases would leave the dispatcher and monitor running.
        """
        if self._running:
            # Stop consumer first so nothing new reaches the dispatcher
            self.stop_consumer()
        else:
            logger.warning(
                "Memory Scheduler consumer is not running; shutting down remaining components"
            )

        # Shutdown dispatcher
        if self.dispatcher:
            logger.info("Shutting down dispatcher...")
            self.dispatcher.shutdown()

        # Shutdown dispatcher_monitor
        if self.dispatcher_monitor:
            logger.info("Shutting down monitor...")
            self.dispatcher_monitor.stop()

    @property
    def handlers(self) -> dict[str, Callable]:
        """
        Access the dispatcher's handlers dictionary.

        Returns:
            dict[str, Callable]: Dictionary mapping labels to handler functions
        """
        if not self.dispatcher:
            logger.warning("Dispatcher is not initialized, returning empty handlers dict")
            return {}

        return self.dispatcher.handlers

    def register_handlers(
        self, handlers: dict[str, Callable[[list[ScheduleMessageItem]], None]]
    ) -> None:
        """
        Bulk register multiple handlers from a dictionary.

        Args:
            handlers: Dictionary mapping labels to handler functions
                      Format: {label: handler_callable}
        """
        if not self.dispatcher:
            logger.warning("Dispatcher is not initialized, cannot register handlers")
            return

        self.dispatcher.register_handlers(handlers)

    def unregister_handlers(self, labels: list[str]) -> dict[str, bool]:
        """
        Unregister handlers from the dispatcher by their labels.

        Args:
            labels: List of labels to unregister handlers for

        Returns:
            dict[str, bool]: Dictionary mapping each label to whether it was successfully unregistered
        """
        if not self.dispatcher:
            logger.warning("Dispatcher is not initialized, cannot unregister handlers")
            return dict.fromkeys(labels, False)

        return self.dispatcher.unregister_handlers(labels)

    def get_running_tasks(self, filter_func: Callable | None = None) -> dict[str, dict]:
        if not self.dispatcher:
            logger.warning("Dispatcher is not initialized, returning empty tasks dict")
            return {}

        running_tasks = self.dispatcher.get_running_tasks(filter_func=filter_func)

        # Convert RunningTaskItem objects to dictionaries for easier consumption
        result = {}
        for task_id, task_item in running_tasks.items():
            result[task_id] = {
                "item_id": task_item.item_id,
                "user_id": task_item.user_id,
                "mem_cube_id": task_item.mem_cube_id,
                "task_info": task_item.task_info,
                "task_name": task_item.task_name,
                "start_time": task_item.start_time,
                "end_time": task_item.end_time,
                "status": task_item.status,
                "result": task_item.result,
                "error_message": task_item.error_message,
                "messages": task_item.messages,
            }

        return result

    def mem_scheduler_wait(
        self, timeout: float = 180.0, poll: float = 0.1, log_every: float = 0.01
    ) -> bool:
        """
        Block until the message queue drains and the dispatcher finishes, or the timeout expires.

        Stage one polls the queue every `poll` seconds. The remaining work is read from
        the queue's `unfinished_tasks` counter (falling back to `qsize()`), an EWMA of
        the completion rate is kept to print an ETA, and a counter that stays unchanged
        and non-zero while the queue is empty and the dispatcher is idle is treated as
        leaked so the wait does not hang. Stage two joins the dispatcher with whatever
        is left of the time budget when parallel dispatch is enabled.

        All counter reads are best-effort: a failing read (including the Redis
        consumer-group lookup) degrades the progress report but never raises.

        Args:
            timeout: Overall budget in seconds shared by both stages.
            poll: Sleep in seconds between two polls of the counters.
            log_every: Minimum seconds between printed progress lines; values below
                0.2 are raised to 0.2.

        Returns:
            bool: True if the queue drained (or its counter was judged leaked) and the
            dispatcher join, when performed, reported success. False if the deadline
            passed while the queue was still busy or the join reported failure.
        """
        deadline = time.monotonic() + timeout

        # --- helpers (local, no external deps) ---
        def _unfinished() -> int:
            """Prefer `unfinished_tasks`; fallback to `qsize()`."""
            try:
                u = getattr(self.memos_message_queue, "unfinished_tasks", None)
                if u is not None:
                    return int(u)
            except Exception:
                pass
            try:
                return int(self.memos_message_queue.qsize())
            except Exception:
                return 0

        def _fmt_eta(seconds: float | None) -> str:
            """Format seconds to human-readable string."""
            # `seconds != seconds` is only true for NaN.
            if seconds is None or seconds != seconds or seconds == float("inf"):
                return "unknown"
            s = max(0, int(seconds))
            h, s = divmod(s, 3600)
            m, s = divmod(s, 60)
            if h > 0:
                return f"{h:d}h{m:02d}m{s:02d}s"
            if m > 0:
                return f"{m:d}m{s:02d}s"
            return f"{s:d}s"

        # --- EWMA throughput state (tasks/s) ---
        alpha = 0.3
        rate = 0.0
        last_t: float | None = None
        last_done = 0

        # --- dynamic totals & stuck detection ---
        init_unfinished = _unfinished()
        done_total = 0
        last_unfinished = None
        stuck_ticks = 0
        next_log = 0.0
        # Warn only once about a failing Redis lookup; the loop runs every `poll` seconds.
        redis_lookup_warned = False

        while True:
            # 1) read counters
            curr_unfinished = _unfinished()
            try:
                qsz = int(self.memos_message_queue.qsize())
            except Exception:
                qsz = -1

            pend = run = 0
            stats_fn = getattr(self.dispatcher, "stats", None)
            if self.enable_parallel_dispatch and self.dispatcher is not None and callable(stats_fn):
                try:
                    st = (
                        stats_fn()
                    )  # expected: {'pending':int,'running':int,'done':int?,'rate':float?}
                    run = int(st.get("running", 0))

                except Exception:
                    # Best-effort: keep run == 0 when dispatcher stats are unavailable.
                    pass

            if isinstance(self.memos_message_queue, SchedulerRedisQueue):
                # For Redis queue, prefer XINFO GROUPS to compute pending
                try:
                    groups_info = self.memos_message_queue.redis.xinfo_groups(
                        self.memos_message_queue.stream_key_prefix
                    )
                    for group in groups_info or ():
                        if group.get("name") == self.memos_message_queue.consumer_group:
                            pend = int(group.get("pending", pend))
                            break
                except Exception as exc:
                    # A Redis error must not abort the wait: fall back to the running
                    # count, as the non-Redis branch does.
                    pend = run
                    if not redis_lookup_warned:
                        logger.warning(
                            f"mem_scheduler_wait: failed to read Redis pending count: {exc}"
                        )
                        redis_lookup_warned = True
            else:
                pend = run

            # 2) dynamic total (allows new tasks queued while waiting)
            total_now = max(init_unfinished, done_total + curr_unfinished)
            done_total = max(0, total_now - curr_unfinished)

            # 3) update EWMA throughput
            now = time.monotonic()
            if last_t is None:
                # First sample: no interval yet, so only record the timestamp.
                last_t = now
            else:
                dt = max(1e-6, now - last_t)
                dc = max(0, done_total - last_done)
                inst = dc / dt
                # Seed the average with the first non-zero reading, then smooth.
                rate = inst if rate == 0.0 else alpha * inst + (1 - alpha) * rate
                last_t = now
                last_done = done_total

            eta = None if rate <= 1e-9 else (curr_unfinished / rate)

            # 4) progress log (throttled)
            if now >= next_log:
                print(
                    f"[mem_scheduler_wait] remaining≈{curr_unfinished} | throughput≈{rate:.2f} msg/s | ETA≈{_fmt_eta(eta)} "
                    f"| qsize={qsz} pending={pend} running={run}"
                )
                next_log = now + max(0.2, log_every)

            # 5) exit / stuck detection
            # Without parallel dispatch there is no dispatcher backlog to consult.
            idle_dispatcher = (
                (pend == 0 and run == 0)
                if (self.enable_parallel_dispatch and self.dispatcher is not None)
                else True
            )
            if curr_unfinished == 0:
                break
            # A tick counts as stuck when work is reportedly outstanding, yet the queue
            # is empty, the dispatcher is idle and the counter did not move.
            if curr_unfinished > 0 and qsz == 0 and idle_dispatcher:
                if last_unfinished == curr_unfinished:
                    stuck_ticks += 1
                else:
                    stuck_ticks = 0
            else:
                stuck_ticks = 0
            last_unfinished = curr_unfinished

            if stuck_ticks >= 3:
                logger.warning(
                    "mem_scheduler_wait: detected leaked 'unfinished_tasks' -> treating queue as drained"
                )
                break

            if now >= deadline:
                logger.warning("mem_scheduler_wait: queue did not drain before timeout")
                return False

            time.sleep(poll)

        # 6) wait dispatcher (second stage)
        remaining = max(0.0, deadline - time.monotonic())
        if self.enable_parallel_dispatch and self.dispatcher is not None:
            try:
                ok = self.dispatcher.join(timeout=remaining)
            except TypeError:
                # Dispatcher whose join() takes no timeout argument.
                ok = self.dispatcher.join()
            if not ok:
                logger.warning("mem_scheduler_wait: dispatcher did not complete before timeout")
                return False

        return True

    def _gather_queue_stats(self) -> dict:
        """
        Build a flat snapshot of queue and dispatcher counters for reporting.

        Returns:
            dict: Always contains "use_redis_queue", "running", "inflight" and
            "handlers". When the Redis queue is not in use it also contains
            "qsize", "unfinished_tasks", "maxsize" and "utilization" for the local
            queue; "qsize" / "unfinished_tasks" are -1 when they cannot be read.
        """
        report: dict[str, int | float | str] = {"use_redis_queue": bool(self.use_redis_queue)}

        if not self.use_redis_queue:
            # Current depth of the local queue.
            try:
                depth = int(self.memos_message_queue.qsize())
            except Exception:
                depth = -1
            report["qsize"] = depth

            # Outstanding-task counter, when the queue exposes one.
            try:
                outstanding = int(getattr(self.memos_message_queue, "unfinished_tasks", 0) or 0)
            except Exception:
                outstanding = -1
            report["unfinished_tasks"] = outstanding

            report["maxsize"] = int(self.max_internal_message_queue_size)

            # Fill ratio clamped to [0.0, 1.0]; a zero maxsize is treated as 1.
            try:
                capacity = int(self.max_internal_message_queue_size) or 1
                fill_ratio = min(1.0, max(0.0, depth / capacity))
            except Exception:
                fill_ratio = 0.0
            report["utilization"] = fill_ratio

        # Dispatcher counters default to zero when stats cannot be obtained.
        counter_names = ("running", "inflight", "handlers")
        try:
            raw = self.dispatcher.stats()
            dispatcher_counts = {name: int(raw.get(name, 0)) for name in counter_names}
        except Exception:
            dispatcher_counts = dict.fromkeys(counter_names, 0)
        report.update(dispatcher_counts)

        return report
